# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from ast import literal_eval
from collections.abc import Callable
from os import environ
from typing import Optional, Union

import mlrun.auth.utils
import mlrun.utils.helpers
from mlrun.config import is_running_as_api

from .utils import AzureVaultStore, list2dict, logger


class SecretsStore:
    """Collect secrets from several kinds of sources and expose them by key.

    Two groups of values are tracked:

    - regular secrets (``inline``, ``file`` and ``env`` sources), which are serialized together
      with their values by :meth:`to_serial`
    - hidden secrets (``azure_vault`` and ``kubernetes`` sources), for which only the source
      definition is serialized, never the values
    """

    def __init__(self):
        # Regular secrets, keyed by (prefix + name).
        self._secrets = {}
        # Values of hidden secrets must never be serialized, only their source definitions may be.
        # These secrets are fetched from an external provider at the time their source is added.
        self._hidden_sources = []
        self._hidden_secrets = {}

    @classmethod
    def from_list(cls, src_list: list):
        """Create a store and add every source described in ``src_list``.

        :param src_list: List of dicts, each with a mandatory ``kind`` key and optional ``source`` and
            ``prefix`` keys. Any falsy or non-list value results in an empty store.
        :return: The populated store.
        """
        store = cls()
        if not (src_list and isinstance(src_list, list)):
            return store
        for definition in src_list:
            store.add_source(
                definition["kind"],
                definition.get("source"),
                definition.get("prefix", ""),
            )
        return store

    def to_dict(self, struct):
        """Placeholder, currently does nothing and returns None."""
        return None

    @staticmethod
    def _coerce_source(source, expected_type, error_message):
        """Return ``source`` as ``expected_type``; a string source is evaluated as a Python literal first.

        :raise ValueError: with ``error_message`` when the (evaluated) source is not of ``expected_type``.
        """
        evaluated = literal_eval(source) if isinstance(source, str) else source
        if not isinstance(evaluated, expected_type):
            raise ValueError(error_message)
        return evaluated

    def _iter_hidden_sources(self, kind):
        """Lazily yield the registered hidden source definitions of the given kind, in insertion order."""
        return (entry for entry in self._hidden_sources if entry["kind"] == kind)

    def add_source(self, kind, source="", prefix=""):
        """Load secrets from a source into the store. Kinds not listed below are silently ignored.

        :param kind: One of ``inline``, ``file``, ``env``, ``azure_vault`` or ``kubernetes``.
        :param source: Kind-specific definition - a dict (or its literal string) for ``inline`` and
            ``azure_vault``, a file path for ``file``, a comma-separated list of environment variable names
            for ``env``, a list (or its literal string) of secret names for ``kubernetes``.
        :param prefix: String prepended to every key stored from this source.
        :raise ValueError: when the source has the wrong type for its kind, or an Azure vault source
            has no ``name``.
        :raise RuntimeError: when a ``file`` source is added while running as the API.
        """
        if kind == "inline":
            inline_secrets = self._coerce_source(
                source, dict, "inline secrets must be of type dict"
            )
            self._secrets.update(
                (prefix + name, str(value)) for name, value in inline_secrets.items()
            )

        elif kind == "file":
            # Ensure files cannot be open from inside the API
            if is_running_as_api():
                raise RuntimeError(
                    "add_source of kind 'file' is not allowed from the API"
                )
            with open(source) as secrets_file:
                file_lines = secrets_file.read().splitlines()
            self._secrets.update(
                (prefix + name, str(value))
                for name, value in list2dict(file_lines).items()
            )

        elif kind == "env":
            for raw_name in source.split(","):
                env_name = raw_name.strip()
                # A variable that is not set is still registered, with a None value
                self._secrets[prefix + env_name] = environ.get(env_name)
        # TODO: Vault: uncomment when vault returns to be relevant
        # elif kind == "vault":
        #     if isinstance(source, str):
        #         source = literal_eval(source)
        #     if not isinstance(source, dict):
        #         raise ValueError("vault secrets must be of type dict")
        #
        #     for key, value in self.vault.get_secrets(
        #         source["secrets"],
        #         user=source.get("user"),
        #         project=source.get("project"),
        #     ).items():
        #         self._hidden_secrets[prefix + key] = value
        #     self._hidden_sources.append({"kind": kind, "source": source})
        elif kind == "azure_vault":
            vault_source = self._coerce_source(
                source, dict, "Azure vault secrets must be of type dict"
            )
            if "name" not in vault_source:
                raise ValueError(
                    "'name' must be provided in the source to define an Azure vault"
                )

            vault_store = AzureVaultStore(vault_source["name"])
            fetched_secrets = vault_store.get_secrets(vault_source["secrets"])
            for name, value in fetched_secrets.items():
                self._hidden_secrets[prefix + name] = value
            self._hidden_sources.append({"kind": kind, "source": vault_source})
        elif kind == "kubernetes":
            secret_names = self._coerce_source(
                source, list, "k8s secrets must be of type list"
            )
            for secret_name in secret_names:
                env_name = self.k8s_env_variable_name_for_secret(secret_name)
                # Only secrets whose environment variable holds a non-empty value are stored
                if mounted_value := environ.get(env_name):
                    self._hidden_secrets[prefix + secret_name] = mounted_value
            self._hidden_sources.append({"kind": kind, "source": secret_names})

    def get(self, key, default=None):
        """Return the first truthy value found for ``key``.

        Lookup order: regular secrets, hidden secrets, the environment variable named by
        :meth:`k8s_env_variable_name_for_secret`, and finally ``default``. Falsy values (such as
        empty strings) are treated as missing at every step.
        """
        for lookup in (self._secrets.get, self._hidden_secrets.get):
            found = lookup(key)
            if found:
                return found
        return environ.get(self.k8s_env_variable_name_for_secret(key)) or default

    def items(self):
        """Return the items of a merged copy of all secrets; hidden secrets override regular ones."""
        merged = {**self._secrets, **self._hidden_secrets}
        return merged.items()

    def to_serial(self):
        """Return the store as a list of source definitions.

        The first element is an ``inline`` source holding a copy of the regular secrets, followed by
        the registered hidden source definitions (without the hidden values).
        """
        # todo: use encryption
        return [
            {"kind": "inline", "source": dict(self._secrets)},
            *self._hidden_sources,
        ]

    def has_vault_source(self):
        """Return True when a source of kind ``vault`` is registered."""
        return any(True for _ in self._iter_hidden_sources("vault"))

    def has_azure_vault_source(self):
        """Return True when a source of kind ``azure_vault`` is registered."""
        return any(True for _ in self._iter_hidden_sources("azure_vault"))

    def get_azure_vault_k8s_secret(self):
        """Return the ``k8s_secret`` field of the first Azure vault source, or None when there is none."""
        first_vault = next(self._iter_hidden_sources("azure_vault"), None)
        if first_vault is None:
            return None
        return first_vault["source"].get("k8s_secret", None)

    @staticmethod
    def k8s_env_variable_name_for_secret(secret_name):
        """Return the configured kubernetes environment variable prefix followed by ``secret_name``."""
        # The configuration object is imported at call time rather than at module import time
        from mlrun.config import config

        env_prefix = config.secret_stores.kubernetes.env_variable_prefix
        return env_prefix + secret_name

    def get_k8s_secrets(self):
        """Map each secret name of the first kubernetes source to its environment variable name.

        :return: The mapping, or None when no kubernetes source is registered.
        """
        first_k8s = next(self._iter_hidden_sources("kubernetes"), None)
        if first_k8s is None:
            return None
        return {
            secret_name: self.k8s_env_variable_name_for_secret(secret_name)
            for secret_name in first_k8s["source"]
        }


def get_secret_or_env(
    key: str,
    secret_provider: Union[dict, SecretsStore, Callable, None] = None,
    default: Optional[str] = None,
    prefix: Optional[str] = None,
) -> Optional[str]:
    """Resolve a secret value from a user-provided provider or from the environment.

    The sources are consulted in the following order, and the first non-empty value is returned:

    1. The `secret_provider`, when given - a dictionary or an MLRun `SecretsStore` is queried with `get`,
       anything else is called with the key
    2. An environment variable with the same key
    3. Any environment variable whose value is a JSON-encoded list of dicts with the fields
       {'name': 'KEY', 'value': 'VAL', 'value_from': ...}, where 'name' equals the key
    4. An MLRun-generated env. variable, mounted from a project secret (to be used in MLRun runtimes)
    5. The default value

    Example::

        secrets = {"KEY1": "VALUE1"}
        secret = get_secret_or_env("KEY1", secret_provider=secrets)


        # Using a function to retrieve a secret
        def my_secret_provider(key):
            # some internal logic to retrieve secret
            return value


        secret = get_secret_or_env(
            "KEY1", secret_provider=my_secret_provider, default="TOO-MANY-SECRETS"
        )

    :param key: Secret key to look for
    :param secret_provider: Dictionary, callable or `SecretsStore` to extract the secret value from. If using a
        callable, it must use the signature `callable(key:str)`
    :param default: Default value to return if secret was not available through any other means
    :param prefix: When passed, the key that is looked up in every source is `<prefix>_<key>`.
    :return: The secret value if found in any of the sources, or `default` otherwise.
    """
    lookup_key = f"{prefix}_{key}" if prefix else key

    if secret_provider:
        if isinstance(secret_provider, dict | SecretsStore):
            fetch = secret_provider.get
        else:
            fetch = secret_provider
        provided_value = fetch(lookup_key)
        if provided_value:
            return provided_value

    plain_env_value = environ.get(lookup_key)
    if plain_env_value:
        return plain_env_value

    # Unlike the other sources, this helper already reports "not found" as None
    value_from_json_list = _find_value_in_json_env_lists(lookup_key)
    if value_from_json_list is not None:
        return value_from_json_list

    project_secret_env_name = SecretsStore.k8s_env_variable_name_for_secret(lookup_key)
    return environ.get(project_secret_env_name) or default


def _find_value_in_json_env_lists(
    secret_name: str,
) -> Optional[str]:
    """
    Scan all environment variables. If any env var contains a JSON-encoded list
    of dicts shaped like {'name': str, 'value': str|None, 'value_from': ...},
    return the 'value' for the entry whose 'name' matches secret_name.
    """
    for environment_variable_value in environ.values():
        if not environment_variable_value or not isinstance(
            environment_variable_value, str
        ):
            continue
        # Fast precheck to skip obvious non-JSON strings
        first_char = environment_variable_value.lstrip()[:1]
        if first_char not in ("[", "{"):
            continue
        try:
            parsed_value = json.loads(environment_variable_value)
        except ValueError:
            continue
        if isinstance(parsed_value, list):
            for entry in parsed_value:
                if isinstance(entry, dict) and entry.get("name") == secret_name:
                    value_in_entry = entry.get("value")
                    # Match original semantics: empty string is treated as "not found"
                    if value_in_entry:
                        return value_in_entry
    return None


@mlrun.utils.iguazio_v4_only
def sync_secret_tokens() -> None:
    """
    Push the secret tokens stored locally to the backend.

    Steps performed:
      1. Read the local token file (default: ~/.igz.yml, configurable via
         `mlrun.mlconf.auth_with_oauth_token.token_file`).
      2. Validate its content and convert the validated tokens into `SecretToken` objects.
      3. Upload the tokens to the backend.
      4. Log a warning if the backend updated any tokens because the local copies carry a
         newer expiration time.

    Nothing is done when the `MLRUN_AUTH_OFFLINE_TOKEN` environment variable is set.
    """
    # TODO: Runtime Context Check - Avoid sending a backend request when running inside a runtime, where secrets
    #  are already injected via Kubernetes and syncing is unnecessary

    # The offline token environment variable takes precedence over the token file, so the file is
    # not synced when it is set. Using the env var is not the recommended approach, and tokens
    # coming from it will not be saved as secrets in the backend.
    if os.getenv("MLRUN_AUTH_OFFLINE_TOKEN"):
        return

    local_tokens = mlrun.auth.utils.load_and_prepare_secret_tokens()

    # Imported here to prevent a circular import, since this method is called from the mlrun.db connection.
    from mlrun.db import get_run_db

    # log_warning=False keeps the SDK from logging unnecessary warnings about local file updates:
    # this method only reads the file, it never updates it.
    run_db = get_run_db()
    store_response = run_db.store_secret_tokens(local_tokens, log_warning=False)

    if not store_response.updated_tokens:
        return

    logger.warning(
        "Some tokens were updated on the backend due to newer expiration found locally",
        updated_tokens=store_response.updated_tokens,
    )
